
import numpy as np
from Env import FiniteStateFiniteActionMDP
import matplotlib.pyplot as plt

class FedQearly_gen:
    """Federated Q-learning with event-triggered communication.

    ``num_agents`` agents share the global estimates (``global_Q``, ``V_func``,
    ``VL``, ``V_ref_func``). Each agent runs episodes greedily with respect to
    ``global_Q`` and collects visit counts and next-state value sums locally.
    Those local statistics are merged into the global estimates by
    :meth:`aggregate_data` once some agent's local visit count for a
    ``(h, s, a)`` reaches the threshold in :meth:`check_event_triggered`.

    Three Q-estimates are maintained:

    * ``QU``: optimistic, with a ``+`` UCB bonus.
    * ``QR``: reference/advantage based, built from ``Vref_sum`` plus the
      advantage ``V - V_ref``.
    * ``QL``: pessimistic, with a ``-`` bonus.

    ``global_Q`` is the running minimum of ``QU``, ``QR`` and itself, so it
    never increases.

    The ``mdp`` object must provide ``H``, ``S``, ``A``,
    ``reset() -> state``, ``step(action) -> (next_state, reward)``,
    ``best_gen() -> (best_value, best_policy, best_Q)`` and
    ``value_gen(policy) -> value`` (indexable by state).

    Args:
        mdp: environment, as described above.
        c1: scale of the UCB bonus used in ``QU``/``QL``.
        c2: scale of the variance-based bonus ``beta_h_k`` used in ``QR``.
        c3: scale of the ``(H - h - 1)**2 / t`` bonus term used in ``QR``.
        beta: once ``V_func - VL`` at ``(h, s)`` drops below this, the
            reference value of ``(h, s)`` is frozen.
        total_episodes: number of rounds run by :meth:`learn`.
        num_agents: number of agents (``M``).
    """

    def __init__(self, mdp, c1, c2, c3, beta, total_episodes, num_agents):
        self.mdp = mdp
        self.total_episodes = total_episodes
        self.num_agents = num_agents
        H, S, A = self.mdp.H, self.mdp.S, self.mdp.A

        # Communication / policy-switch bookkeeping.
        self.trigger_times = 0             # number of aggregation rounds so far
        self.comm_episode_collection = []  # rounds in which aggregation happened
        self.Nswitch = 0                   # rounds whose policy differs from the previous one
        self.Nswitch1 = 0                  # NOTE(review): never updated in this class
        self.n_switch = np.zeros((H, S, A))  # policy of the previous round (set in learn)

        # Learning curves filled by learn().
        self.regret = []
        self.globalcost = []
        self.cost = []

        self.c1 = c1
        self.c2 = c2
        self.c3 = c3
        self.beta = beta

        # Value estimates. Row H is the terminal step and is never written, so it stays 0.
        self.V_func = np.zeros((H + 1, S), dtype=np.float32)
        self.VL = np.zeros((H + 1, S), dtype=np.float32)
        self.V_ref_func = np.zeros((H + 1, S), dtype=np.float32)
        # Set to 1 once the reference value of (h, s) has been frozen.
        self.V_ref_trigger = np.zeros((H, S), dtype=np.int32)

        # Optimistic estimates start at the remaining horizon H - h.
        horizon_init = np.zeros((H, S, A), dtype=np.float32)
        for h in range(H):
            horizon_init[h, :, :] = H - h
        self.global_Q = horizon_init.copy()
        self.QU = horizon_init.copy()
        self.QR = horizon_init.copy()
        self.QL = np.zeros((H, S, A), dtype=np.float32)

        # Global visit counts (updated only at aggregation time).
        self.N = np.zeros((H, S, A), dtype=np.int32)
        self.n = np.zeros((H, S, A), dtype=np.int32)

        # Global statistics:
        #  - Vref_sum / Vref2_sum: running means of V_ref(s') and V_ref(s')^2.
        #  - Vadv_sum / Vadv2_sum: step-size-weighted means of the advantage
        #    V(s') - V_ref(s') and of its square.
        self.Vref_sum = np.zeros((H, S, A), dtype=np.float32)
        self.Vref2_sum = np.zeros((H, S, A), dtype=np.float32)
        self.Vadv_sum = np.zeros((H, S, A), dtype=np.float32)
        self.Vadv2_sum = np.zeros((H, S, A), dtype=np.float32)

        # Per-agent local buffers for the current round; cleared by aggregate_data().
        self.agent_N = np.zeros((num_agents, H, S, A), dtype=np.int32)
        self.agent_V_nextsum = np.zeros((num_agents, H, S, A), dtype=np.float32)
        self.agent_VL_nextsum = np.zeros((num_agents, H, S, A), dtype=np.float32)
        self.agent_Vref_nextsum = np.zeros((num_agents, H, S, A), dtype=np.float32)
        self.agent_Vref2_nextsum = np.zeros((num_agents, H, S, A), dtype=np.float32)
        # Sum of per-visit squared advantages (V(s') - V_ref(s'))^2. Needed so the
        # batched update can form a true second moment.
        self.agent_Vadv2_nextsum = np.zeros((num_agents, H, S, A), dtype=np.float32)

    def run_episode(self, agent_id):
        """Run one episode for ``agent_id``, acting greedily w.r.t. ``global_Q``.

        Every visited ``(step, state, action)`` adds to the agent's local
        buffers, using the current global value estimates at the next state.

        Returns:
            rewards: (H, S, A) array holding the reward observed at each
                visited triple (0 elsewhere).
            event_triggered: True if any visit met the communication condition.
            state_init: the initial state returned by ``mdp.reset()``.
        """
        event_triggered = False
        actions_policy = self.choose_action()
        state = self.mdp.reset()
        state_init = state
        rewards = np.zeros((self.mdp.H, self.mdp.S, self.mdp.A))

        for step in range(self.mdp.H):
            action = np.argmax(actions_policy[step, state])
            next_state, reward = self.mdp.step(action)

            v_next = self.V_func[step + 1, next_state]
            vref_next = self.V_ref_func[step + 1, next_state]
            idx = (agent_id, step, state, action)
            self.agent_N[idx] += 1
            self.agent_V_nextsum[idx] += v_next
            self.agent_VL_nextsum[idx] += self.VL[step + 1, next_state]
            self.agent_Vref_nextsum[idx] += vref_next
            self.agent_Vref2_nextsum[idx] += vref_next ** 2
            self.agent_Vadv2_nextsum[idx] += (v_next - vref_next) ** 2

            rewards[step, state, action] = reward
            if self.check_event_triggered(agent_id, step, state, action):
                event_triggered = True
            state = next_state
        return rewards, event_triggered, state_init

    def choose_action(self):
        """Return the one-hot (H, S, A) greedy policy of ``global_Q``."""
        actions = np.zeros([self.mdp.H, self.mdp.S, self.mdp.A])

        for step in range(self.mdp.H):
            for state in range(self.mdp.S):
                best_action = np.argmax(self.global_Q[step, state])
                actions[step, state, best_action] = 1

        return actions

    def check_event_triggered(self, agent_id, step, state, action):
        """Return True when the agent's local count for the triple hits the threshold.

        The threshold is ``max(1, floor(N / (M * H * (H + 1))))``, where ``N``
        is the global visit count of the triple.
        """
        tilde_C = 1.0 / (self.mdp.H * (self.mdp.H + 1))
        global_visits = self.N[step, state, action]
        threshold = max(1, int(np.floor((tilde_C / self.num_agents) * global_visits)))
        return self.agent_N[agent_id, step, state, action] >= threshold

    def aggregate_data(self, policy_k, rewards):
        """Merge the agents' local statistics into the global estimates.

        Only the greedy action of ``policy_k`` at each ``(h, s)`` is updated,
        and only if some agent visited it this round. The update depends on
        the global count ``N``:

        * ``N < i_0 = 2 * M * H * (H + 1)``: each agent's sample is applied in
          turn with step size ``(H + 1) / (H + t)``.
        * Otherwise: all ``n`` new samples are applied at once. The product of
          the ``n`` step-size complements weights the old value, and the
          samples enter as equally weighted means.

        Afterwards ``N`` is advanced and every local buffer is cleared.

        Args:
            policy_k: (H, S, A) one-hot policy the agents followed this round.
            rewards: (H, S, A) array of observed rewards (treated as
                deterministic by :meth:`learn`).
        """
        H, M = self.mdp.H, self.num_agents
        i_0 = 2 * M * H * (H + 1)
        for h in range(H):
            for s in range(self.mdp.S):
                for a in range(self.mdp.A):
                    if a != np.argmax(policy_k[h, s]) or self.agent_N[:, h, s, a].sum() == 0:
                        # Not updated this round; previous estimates are kept.
                        continue

                    N_h_k = self.N[h, s, a]
                    n_h_k = self.agent_N[:, h, s, a].sum()

                    # Variance-based bonus from the statistics *before* this round.
                    eref_old = max(self.Vref2_sum[h, s, a] - (self.Vref_sum[h, s, a]) ** 2, 0)
                    eadv_old = max(H * (self.Vadv2_sum[h, s, a] - (self.Vadv_sum[h, s, a]) ** 2), 0)
                    if N_h_k == 0:
                        beta_h_k = 0
                    else:
                        beta_h_k = self.c2 * (np.sqrt(eref_old / N_h_k) + np.sqrt(eadv_old / N_h_k))

                    # Plain running means of V_ref(s') and V_ref(s')^2 over all visits.
                    self.Vref_sum[h, s, a] = (
                        N_h_k * self.Vref_sum[h, s, a] + self.agent_Vref_nextsum[:, h, s, a].sum()
                    ) / (N_h_k + n_h_k)
                    self.Vref2_sum[h, s, a] = (
                        N_h_k * self.Vref2_sum[h, s, a] + self.agent_Vref2_nextsum[:, h, s, a].sum()
                    ) / (N_h_k + n_h_k)
                    eref_new = max(self.Vref2_sum[h, s, a] - (self.Vref_sum[h, s, a]) ** 2, 0)

                    t00 = N_h_k  # index of the most recently applied sample
                    if N_h_k < i_0:
                        # Here the trigger threshold is 1. With learn()'s
                        # one-episode-per-agent rounds, each agent therefore
                        # holds at most one sample, so per-agent sums are single samples.
                        for ag_id in range(self.num_agents):
                            if self.agent_N[ag_id, h, s, a] > 0:
                                t00 = t00 + 1
                                step_size = (H + 1) / (H + t00)
                                ucb_bonus = self.c1 * (H - h - 1) * np.sqrt(H / t00)
                                ref_bonus_small = beta_h_k + self.c3 * (H - h - 1) ** 2 / t00
                                vadv = self.agent_V_nextsum[ag_id, h, s, a] - self.agent_Vref_nextsum[ag_id, h, s, a]
                                self.Vadv_sum[h, s, a] = (1 - step_size) * self.Vadv_sum[h, s, a] + step_size * vadv
                                self.Vadv2_sum[h, s, a] = (1 - step_size) * self.Vadv2_sum[h, s, a] + step_size * vadv ** 2
                                self.QU[h, s, a] = (1 - step_size) * self.QU[h, s, a] + step_size * (
                                    rewards[h, s, a] + self.agent_V_nextsum[ag_id, h, s, a] + ucb_bonus)
                                self.QL[h, s, a] = (1 - step_size) * self.QL[h, s, a] + step_size * (
                                    rewards[h, s, a] + self.agent_VL_nextsum[ag_id, h, s, a] - ucb_bonus)
                                self.QR[h, s, a] = (1 - step_size) * self.QR[h, s, a] + step_size * (
                                    rewards[h, s, a] + self.Vref_sum[h, s, a] + vadv + ref_bonus_small)

                        N_h_k_new = N_h_k + n_h_k
                        eadv_new = max(H * (self.Vadv2_sum[h, s, a] - (self.Vadv_sum[h, s, a]) ** 2), 0)
                        beta_h_k_new = self.c2 * (np.sqrt(eref_new / N_h_k_new) + np.sqrt(eadv_new / N_h_k_new))
                        # Swap the old variance bonus for the updated one.
                        self.QR[h, s, a] = self.QR[h, s, a] + beta_h_k_new - beta_h_k
                        self.global_Q[h, s, a] = min([self.QU[h, s, a], self.QR[h, s, a], self.global_Q[h, s, a]])

                    else:
                        # alpha_agg_side: product of (1 - step_size) over the new samples.
                        # bonus / ref_bonus: step-size-weighted accumulations of the
                        # per-sample bonuses.
                        alpha_agg_side = 1.0
                        bonus = 0
                        ref_bonus = 0
                        for _ in range(n_h_k):
                            t00 = t00 + 1
                            step_size = (H + 1) / (H + t00)
                            alpha_agg_side = alpha_agg_side * (1 - step_size)
                            ucb_bonus = self.c1 * (H - h - 1) * np.sqrt(H / t00)
                            ref_bonus_large = beta_h_k + self.c3 * (H - h - 1) ** 2 / t00
                            bonus = (1 - step_size) * bonus + step_size * ucb_bonus
                            ref_bonus = (1 - step_size) * ref_bonus + step_size * ref_bonus_large

                        vadv = self.agent_V_nextsum[:, h, s, a].sum() - self.agent_Vref_nextsum[:, h, s, a].sum()
                        self.Vadv_sum[h, s, a] = alpha_agg_side * self.Vadv_sum[h, s, a] + (1 - alpha_agg_side) * vadv / n_h_k
                        # Second moment = mean of the per-sample squared advantages.
                        # Squaring the *summed* advantage instead inflates it by
                        # up to a factor of n_h_k.
                        vadv2 = self.agent_Vadv2_nextsum[:, h, s, a].sum()
                        self.Vadv2_sum[h, s, a] = alpha_agg_side * self.Vadv2_sum[h, s, a] + (1 - alpha_agg_side) * vadv2 / n_h_k

                        N_h_k_new = N_h_k + n_h_k
                        eadv_new = max(H * (self.Vadv2_sum[h, s, a] - (self.Vadv_sum[h, s, a]) ** 2), 0)
                        beta_h_k_new = self.c2 * (np.sqrt(eref_new / N_h_k_new) + np.sqrt(eadv_new / N_h_k_new))
                        self.QU[h, s, a] = alpha_agg_side * self.QU[h, s, a] + (1 - alpha_agg_side) * (
                            rewards[h, s, a] + self.agent_V_nextsum[:, h, s, a].sum() / n_h_k) + bonus
                        self.QL[h, s, a] = alpha_agg_side * self.QL[h, s, a] + (1 - alpha_agg_side) * (
                            rewards[h, s, a] + self.agent_VL_nextsum[:, h, s, a].sum() / n_h_k) - bonus
                        self.QR[h, s, a] = alpha_agg_side * self.QR[h, s, a] + \
                            (1 - alpha_agg_side) * (rewards[h, s, a] + self.Vref_sum[h, s, a] + vadv / n_h_k) + \
                            ref_bonus + beta_h_k_new - beta_h_k
                        self.global_Q[h, s, a] = min([self.QU[h, s, a], self.QR[h, s, a], self.global_Q[h, s, a]])

        self.N += self.agent_N.sum(axis=0)
        for buffer in (self.agent_N, self.agent_V_nextsum, self.agent_VL_nextsum,
                       self.agent_Vref_nextsum, self.agent_Vref2_nextsum,
                       self.agent_Vadv2_nextsum):
            buffer.fill(0)

    def update_reference(self, h, s):
        """Refresh the reference value of ``(h, s)`` unless it is frozen.

        The reference follows ``V_func``. When the gap ``V_func - VL`` drops
        below ``beta``, it is set one last time and then frozen.
        """
        if self.V_ref_trigger[h, s] == 1:
            return
        if self.V_func[h, s] - self.VL[h, s] < self.beta:
            self.V_ref_trigger[h, s] = 1
        self.V_ref_func[h, s] = self.V_func[h, s]

    def learn(self):
        """Run ``total_episodes`` rounds of federated learning.

        In each round every agent runs one episode. If any agent met the
        communication condition, the local statistics are aggregated and the
        greedy policy and value estimates are refreshed.

        Learning curves:
            * ``regret``: cumulative regret, appended once per agent-episode.
            * ``cost``: communication rounds so far, appended once per round.
            * ``globalcost``: policy switches so far, appended once per round.

        Returns:
            (best_Q, global_Q): the optimal Q from ``mdp.best_gen()`` and the
            learned global Q-estimate.
        """
        # Cumulative regret summed over all agent-episodes.
        self.regret_cum = 0
        best_value, _, best_Q = self.mdp.best_gen()
        H, S = self.mdp.H, self.mdp.S
        event_triggered = False
        # Rewards are treated as deterministic. A slot for the greedy (h, s, a)
        # is filled from the agents' observations only while it is still 0.
        rewards = np.zeros((H, S, self.mdp.A))
        for h in range(H):
            for s in range(S):
                self.V_func[h, s] = max(self.global_Q[h, s, :])
                self.V_ref_func[h, s] = self.V_func[h, s]
        actions_policy = self.choose_action()
        self.n_switch = actions_policy

        for episode in range(self.total_episodes):
            value = self.mdp.value_gen(actions_policy)
            for agent_id in range(self.num_agents):
                agent_reward, agent_event_triggered, state_init = self.run_episode(agent_id)
                self.regret_cum = self.regret_cum + best_value[state_init] - value[state_init]
                self.regret.append(self.regret_cum)

                for h in range(H):
                    for s in range(S):
                        a = np.argmax(actions_policy[h, s])
                        if rewards[h, s, a] == 0:
                            rewards[h, s, a] = agent_reward[h, s, a]

                if agent_event_triggered:
                    event_triggered = True

            if event_triggered:
                self.trigger_times += 1
                self.comm_episode_collection.append(episode)
                self.aggregate_data(actions_policy, rewards)
                event_triggered = False
                actions_policy = self.choose_action()
                for h in range(H):
                    for s in range(S):
                        self.V_func[h, s] = max(self.global_Q[h, s, :])
                        # VL is kept monotonically non-decreasing.
                        self.VL[h, s] = max(max(self.QL[h, s, :]), self.VL[h, s])
                        self.update_reference(h, s)
            self.cost.append(self.trigger_times)
            if not np.array_equal(self.n_switch, actions_policy):
                self.Nswitch += 1
            self.globalcost.append(self.Nswitch)
            self.n_switch = actions_policy
        return best_Q, self.global_Q